!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief computes preconditioners, and implements methods to apply them
!>      currently used in qs_ot
!> \par History
!>      - [UB] 2009-05-13 Adding stable approximate inverse (full and sparse)
!> \author Joost VandeVondele (09.2002)
! **************************************************************************************************
MODULE preconditioner_makes
   USE arnoldi_api,                     ONLY: arnoldi_env_type,&
                                              arnoldi_ev,&
                                              deallocate_arnoldi_env,&
                                              get_selected_ritz_val,&
                                              get_selected_ritz_vector,&
                                              set_arnoldi_initial_vector,&
                                              setup_arnoldi_env
   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_add, dbcsr_copy, dbcsr_create, dbcsr_get_info, dbcsr_multiply, dbcsr_p_type, &
        dbcsr_release, dbcsr_type, dbcsr_type_symmetric
   USE cp_dbcsr_contrib,                ONLY: dbcsr_add_on_diag
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
                                              cp_dbcsr_m_by_n_from_template,&
                                              cp_dbcsr_sm_fm_multiply,&
                                              cp_fm_to_dbcsr_row_template
   USE cp_fm_basic_linalg,              ONLY: cp_fm_column_scale,&
                                              cp_fm_triangular_invert,&
                                              cp_fm_triangular_multiply,&
                                              cp_fm_uplo_to_full
   USE cp_fm_cholesky,                  ONLY: cp_fm_cholesky_decompose,&
                                              cp_fm_cholesky_reduce,&
                                              cp_fm_cholesky_restore
   USE cp_fm_diag,                      ONLY: choose_eigv_solver
   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_diag,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_to_fm,&
                                              cp_fm_type
   USE input_constants,                 ONLY: &
        cholesky_inverse, cholesky_reduce, ot_precond_full_all, ot_precond_full_kinetic, &
        ot_precond_full_single, ot_precond_full_single_inverse, ot_precond_s_inverse, &
        ot_precond_solver_default, ot_precond_solver_inv_chol
   USE kinds,                           ONLY: dp
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE preconditioner_types,            ONLY: preconditioner_type
#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'preconditioner_makes'

   PUBLIC :: make_preconditioner_matrix

CONTAINS

! **************************************************************************************************
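!> \brief builds the preconditioner selected by preconditioner_env%in_use by dispatching to the
!>        matching make_* routine of this module
!> \param preconditioner_env preconditioner environment; its in_use component selects the type to build
!> \param matrix_h Hamiltonian matrix
!> \param matrix_s overlap matrix; if absent the *_ortho variants are used for
!>        ot_precond_full_single and ot_precond_full_all
!> \param matrix_t kinetic matrix, required (together with matrix_s) for ot_precond_full_kinetic
!> \param mo_coeff passed on as matrix_c0 for ot_precond_full_single_inverse and ot_precond_full_all
!> \param energy_homo energy subtracted from the eigenvalues (ot_precond_full_single only)
!> \param eigenvalues_ot passed on as c0_evals (ot_precond_full_all only)
!> \param energy_gap energy gap passed on to the selected routine (not used for ot_precond_s_inverse)
!> \param my_solver_type requested solver; a default request is replaced by ot_precond_solver_inv_chol for
!>        ot_precond_s_inverse, ot_precond_full_kinetic and ot_precond_full_single_inverse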
! **************************************************************************************************
   SUBROUTINE make_preconditioner_matrix(preconditioner_env, matrix_h, matrix_s, matrix_t, mo_coeff, &
                                         energy_homo, eigenvalues_ot, energy_gap, &
                                         my_solver_type)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(dbcsr_type), POINTER                          :: matrix_h
      TYPE(dbcsr_type), OPTIONAL, POINTER                :: matrix_s, matrix_t
      TYPE(cp_fm_type), INTENT(IN)                       :: mo_coeff
      REAL(KIND=dp)                                      :: energy_homo
      REAL(KIND=dp), DIMENSION(:)                        :: eigenvalues_ot
      REAL(KIND=dp)                                      :: energy_gap
      INTEGER                                            :: my_solver_type

      LOGICAL                                            :: default_solver, have_s, have_t

      ! Dispatcher: picks the construction routine that belongs to preconditioner_env%in_use.
      ! The optional matrices decide between the general and the orthonormal-basis (S=1) variants,
      ! my_solver_type is validated or, where a default was requested, replaced by inv_chol.
      have_s = PRESENT(matrix_s)
      have_t = PRESENT(matrix_t)
      default_solver = (my_solver_type == ot_precond_solver_default)

      SELECT CASE (preconditioner_env%in_use)

      CASE (ot_precond_full_single)
         ! dense preconditioner, only available with the default solver
         IF (.NOT. default_solver) THEN
            CPABORT("Only PRECOND_SOLVER DEFAULT for the moment")
         END IF
         IF (have_s) THEN
            CALL make_full_single(preconditioner_env, preconditioner_env%fm, &
                                  matrix_h, matrix_s, energy_homo, energy_gap)
         ELSE
            CALL make_full_single_ortho(preconditioner_env, preconditioner_env%fm, &
                                        matrix_h, energy_homo, energy_gap)
         END IF

      CASE (ot_precond_s_inverse)
         ! sparse preconditioner built from the overlap matrix, which is therefore mandatory
         IF (default_solver) my_solver_type = ot_precond_solver_inv_chol
         IF (.NOT. have_s) THEN
            CPABORT("Type for S=1 not implemented")
         END IF
         CALL make_full_s_inverse(preconditioner_env, matrix_s)

      CASE (ot_precond_full_kinetic)
         ! sparse preconditioner built from kinetic and overlap matrix, both are mandatory
         IF (default_solver) my_solver_type = ot_precond_solver_inv_chol
         IF ((.NOT. have_s) .OR. (.NOT. have_t)) THEN
            CPABORT("Type for S=1 not implemented")
         END IF
         CALL make_full_kinetic(preconditioner_env, matrix_t, matrix_s, energy_gap)

      CASE (ot_precond_full_single_inverse)
         ! matrix_s is handed on as an optional argument, so it may be absent here
         IF (default_solver) my_solver_type = ot_precond_solver_inv_chol
         CALL make_full_single_inverse(preconditioner_env, mo_coeff, matrix_h, energy_gap, &
                                       matrix_s=matrix_s)

      CASE (ot_precond_full_all)
         ! state-by-state preconditioner, only available with the default solver
         IF (.NOT. default_solver) THEN
            CPABORT("Only PRECOND_SOLVER DEFAULT for the moment")
         END IF
         IF (have_s) THEN
            CALL make_full_all(preconditioner_env, mo_coeff, matrix_h, matrix_s, &
                               eigenvalues_ot, energy_gap)
         ELSE
            CALL make_full_all_ortho(preconditioner_env, mo_coeff, matrix_h, &
                                     eigenvalues_ot, energy_gap)
         END IF

      CASE DEFAULT
         CPABORT("Type not implemented")

      END SELECT

   END SUBROUTINE make_preconditioner_matrix

! **************************************************************************************************
!> \brief Simply takes the overlap matrix as preconditioner
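!> \param preconditioner_env environment whose sparse_matrix component receives the copy
!>        (allocated here if it is not yet associated)
!> \param matrix_s overlap matrix to be copied; has to be associated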
! **************************************************************************************************
   SUBROUTINE make_full_s_inverse(preconditioner_env, matrix_s)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(dbcsr_type), POINTER                          :: matrix_s

      CHARACTER(len=*), PARAMETER :: routineN = 'make_full_s_inverse'

      INTEGER                                            :: timer_handle

      ! Stores a copy of the overlap matrix in preconditioner_env%sparse_matrix.
      CALL timeset(routineN, timer_handle)

      ! the overlap matrix is the only input, so it has to exist
      CPASSERT(ASSOCIATED(matrix_s))

      ! create the target on first use, an already associated target is reused
      IF (.NOT. ASSOCIATED(preconditioner_env%sparse_matrix)) &
         ALLOCATE (preconditioner_env%sparse_matrix)

      ! NOTE(review): the copy is labelled "full_kinetic" although it holds the overlap matrix,
      !               this looks like a leftover from make_full_kinetic - confirm before renaming
      CALL dbcsr_copy(preconditioner_env%sparse_matrix, matrix_s, name="full_kinetic")

      CALL timestop(timer_handle)

   END SUBROUTINE make_full_s_inverse

! **************************************************************************************************
!> \brief kinetic matrix+shift*overlap as preconditioner. Cheap but could
!>        be better
!> \param preconditioner_env ...
!> \param matrix_t ...
!> \param matrix_s ...
!> \param energy_gap ...
! **************************************************************************************************
   SUBROUTINE make_full_kinetic(preconditioner_env, matrix_t, matrix_s, &
                                energy_gap)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(dbcsr_type), POINTER                          :: matrix_t, matrix_s
      REAL(KIND=dp)                                      :: energy_gap

      CHARACTER(len=*), PARAMETER                        :: routineN = 'make_full_kinetic'

      INTEGER                                            :: timer_handle

      ! Stores T + shift*S in preconditioner_env%sparse_matrix, with shift = MAX(0, energy_gap).
      CALL timeset(routineN, timer_handle)

      ! both input matrices are needed
      CPASSERT(ASSOCIATED(matrix_t))
      CPASSERT(ASSOCIATED(matrix_s))

      ! create the target on first use, an already associated target is reused
      IF (.NOT. ASSOCIATED(preconditioner_env%sparse_matrix)) &
         ALLOCATE (preconditioner_env%sparse_matrix)

      ! start from the kinetic matrix ...
      CALL dbcsr_copy(preconditioner_env%sparse_matrix, matrix_t, name="full_kinetic")

      ! ... and add the overlap matrix scaled by the shift, a negative gap is clamped to zero
      CALL dbcsr_add(preconditioner_env%sparse_matrix, matrix_s, &
                     alpha_scalar=1.0_dp, beta_scalar=MAX(0.0_dp, energy_gap))

      CALL timestop(timer_handle)

   END SUBROUTINE make_full_kinetic

! **************************************************************************************************
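!> \brief full_single_preconditioner: dense matrix built from the eigenvectors of the generalized
!>        problem (matrix_h, matrix_s), each weighted by 1/MAX(eigenvalue - energy_homo, energy_gap)
!> \param preconditioner_env supplies ctxt, para_env and the cholesky_use setting
!> \param fm dense preconditioner matrix; an associated input is released and a new matrix is allocated
!> \param matrix_h Hamiltonian matrix
!> \param matrix_s overlap matrix
!> \param energy_homo energy subtracted from every eigenvalue before the inversion
!> \param energy_gap lower bound applied to the shifted eigenvalues before the inversion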
! **************************************************************************************************
   SUBROUTINE make_full_single(preconditioner_env, fm, matrix_h, matrix_s, &
                               energy_homo, energy_gap)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(cp_fm_type), POINTER                          :: fm
      TYPE(dbcsr_type), POINTER                          :: matrix_h, matrix_s
      REAL(KIND=dp)                                      :: energy_homo, energy_gap

      CHARACTER(len=*), PARAMETER                        :: routineN = 'make_full_single'

      INTEGER                                            :: ndim, timer_handle
      REAL(KIND=dp), DIMENSION(:), POINTER               :: weights
      TYPE(cp_fm_struct_type), POINTER                   :: square_struct
      TYPE(cp_fm_type)                                   :: fm_ham, fm_ovl

      ! Builds the dense preconditioner fm = V diag(w) V^T, where V are the eigenvectors of the
      ! generalized problem (matrix_h, matrix_s) and w = 1/MAX(eigenvalue - energy_homo, energy_gap).
      CALL timeset(routineN, timer_handle)

      NULLIFY (square_struct, weights)

      ! get rid of a preconditioner matrix left over from an earlier call
      IF (ASSOCIATED(fm)) THEN
         CALL cp_fm_release(fm)
         DEALLOCATE (fm)
         NULLIFY (fm)
      END IF

      ! all dense matrices are square with the row dimension of matrix_h
      CALL dbcsr_get_info(matrix_h, nfullrows_total=ndim)
      ALLOCATE (weights(ndim))

      CALL cp_fm_struct_create(square_struct, nrow_global=ndim, ncol_global=ndim, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      ALLOCATE (fm)
      CALL cp_fm_create(fm, square_struct, name="preconditioner")
      CALL cp_fm_create(fm_ham, square_struct, name="fm_h")
      CALL cp_fm_create(fm_ovl, square_struct, name="fm_s")
      CALL cp_fm_struct_release(square_struct)

      ! dense copies of H and S, S is replaced by its Cholesky factor
      CALL copy_dbcsr_to_fm(matrix_h, fm_ham)
      CALL copy_dbcsr_to_fm(matrix_s, fm_ovl)
      CALL cp_fm_cholesky_decompose(fm_ovl)

      ! turn the generalized eigenvalue problem into a standard one
      IF (preconditioner_env%cholesky_use == cholesky_inverse) THEN
         ! explicit inverse of the factor, applied from the right and (transposed) from the left;
         ! fm only serves as second argument of cp_fm_uplo_to_full here and is overwritten below
         CALL cp_fm_triangular_invert(fm_ovl)
         CALL cp_fm_uplo_to_full(fm_ham, fm)

         CALL cp_fm_triangular_multiply(fm_ovl, fm_ham, side="R", transpose_tr=.FALSE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=ndim, n_cols=ndim, &
                                        alpha=1.0_dp)
         CALL cp_fm_triangular_multiply(fm_ovl, fm_ham, side="L", transpose_tr=.TRUE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=ndim, n_cols=ndim, &
                                        alpha=1.0_dp)
      ELSE IF (preconditioner_env%cholesky_use == cholesky_reduce) THEN
         CALL cp_fm_cholesky_reduce(fm_ham, fm_ovl)
      ELSE
         CPABORT("cholesky type not implemented")
      END IF

      ! eigenvectors end up in fm, eigenvalues in weights
      CALL choose_eigv_solver(fm_ham, fm, weights)

      ! eigenvalues -> inverse of the shifted eigenvalues, bounded from below by energy_gap
      weights(:) = 1.0_dp/MAX(weights(:) - energy_homo, energy_gap)

      ! back-transform the eigenvectors; afterwards fm and fm_ham hold the same vectors
      IF (preconditioner_env%cholesky_use == cholesky_inverse) THEN
         CALL cp_fm_triangular_multiply(fm_ovl, fm, side="L", transpose_tr=.FALSE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=ndim, n_cols=ndim, &
                                        alpha=1.0_dp)
         CALL cp_fm_to_fm(fm, fm_ham)
      ELSE IF (preconditioner_env%cholesky_use == cholesky_reduce) THEN
         CALL cp_fm_cholesky_restore(fm, ndim, fm_ovl, fm_ham, "SOLVE")
         CALL cp_fm_to_fm(fm_ham, fm)
      END IF

      ! fm = (V diag(w)) V^T, fm_ovl is no longer needed and serves as target of the product
      CALL cp_fm_column_scale(fm, weights)
      CALL parallel_gemm('N', 'T', ndim, ndim, ndim, 1.0_dp, fm, fm_ham, 0.0_dp, fm_ovl)
      CALL cp_fm_to_fm(fm_ovl, fm)

      DEALLOCATE (weights)
      CALL cp_fm_release(fm_ham)
      CALL cp_fm_release(fm_ovl)

      CALL timestop(timer_handle)

   END SUBROUTINE make_full_single

! **************************************************************************************************
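!> \brief full single in the orthonormal basis: dense matrix built from the eigenvectors of matrix_h,
!>        each weighted by 1/MAX(eigenvalue - energy_homo, energy_gap)
!> \param preconditioner_env supplies ctxt and para_env for the dense matrices
!> \param fm dense preconditioner matrix; an associated input is released and a new matrix is allocated
!> \param matrix_h Hamiltonian matrix
!> \param energy_homo energy subtracted from every eigenvalue before the inversion
!> \param energy_gap lower bound applied to the shifted eigenvalues before the inversion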
! **************************************************************************************************
   SUBROUTINE make_full_single_ortho(preconditioner_env, fm, matrix_h, &
                                     energy_homo, energy_gap)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(cp_fm_type), POINTER                          :: fm
      TYPE(dbcsr_type), POINTER                          :: matrix_h
      REAL(KIND=dp)                                      :: energy_homo, energy_gap

      CHARACTER(len=*), PARAMETER :: routineN = 'make_full_single_ortho'

      INTEGER                                            :: ndim, timer_handle
      REAL(KIND=dp), DIMENSION(:), POINTER               :: weights
      TYPE(cp_fm_struct_type), POINTER                   :: square_struct
      TYPE(cp_fm_type)                                   :: fm_evecs, fm_product

      ! Same construction as make_full_single without an overlap matrix:
      ! fm = V diag(w) V^T with V the eigenvectors of matrix_h and
      ! w = 1/MAX(eigenvalue - energy_homo, energy_gap).
      CALL timeset(routineN, timer_handle)
      NULLIFY (square_struct, weights)

      ! get rid of a preconditioner matrix left over from an earlier call
      IF (ASSOCIATED(fm)) THEN
         CALL cp_fm_release(fm)
         DEALLOCATE (fm)
         NULLIFY (fm)
      END IF

      ! all dense matrices are square with the row dimension of matrix_h
      CALL dbcsr_get_info(matrix_h, nfullrows_total=ndim)
      ALLOCATE (weights(ndim))

      CALL cp_fm_struct_create(square_struct, nrow_global=ndim, ncol_global=ndim, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      ALLOCATE (fm)
      CALL cp_fm_create(fm, square_struct, name="preconditioner")
      CALL cp_fm_create(fm_evecs, square_struct, name="fm_h")
      CALL cp_fm_create(fm_product, square_struct, name="fm_s")
      CALL cp_fm_struct_release(square_struct)

      ! fm_evecs first carries the dense copy of H, which is handed to the eigensolver
      CALL copy_dbcsr_to_fm(matrix_h, fm_evecs)
      CALL choose_eigv_solver(fm_evecs, fm, weights)

      ! eigenvalues -> inverse of the shifted eigenvalues, bounded from below by energy_gap
      weights(:) = 1.0_dp/MAX(weights(:) - energy_homo, energy_gap)

      ! keep an unscaled copy of the eigenvectors, then form (V diag(w)) V^T
      CALL cp_fm_to_fm(fm, fm_evecs)
      CALL cp_fm_column_scale(fm, weights)
      CALL parallel_gemm('N', 'T', ndim, ndim, ndim, 1.0_dp, fm, fm_evecs, 0.0_dp, fm_product)
      CALL cp_fm_to_fm(fm_product, fm)

      DEALLOCATE (weights)
      CALL cp_fm_release(fm_evecs)
      CALL cp_fm_release(fm_product)

      CALL timestop(timer_handle)

   END SUBROUTINE make_full_single_ortho

! **************************************************************************************************
!> \brief generates a state by state preconditioner based on the full hamiltonian matrix
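!> \param preconditioner_env supplies ctxt, para_env and cholesky_use; on exit its components fm,
!>        full_evals, occ_evals and energy_gap are set
!> \param matrix_c0 current orbitals c0 (n rows, k columns)
!> \param matrix_h Hamiltonian matrix
!> \param matrix_s overlap matrix
!> \param c0_evals Ritz values belonging to the columns of matrix_c0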
!> \param energy_gap should be a slight underestimate of the physical energy gap for almost all systems
!>      the c0 are already ritz states of (h,s)
!> \par History
!>      10.2006 made more stable [Joost VandeVondele]
!> \note
!>      includes error estimate on the hamiltonian matrix to result in a stable preconditioner
!>      a preconditioner for each eigenstate i is generated by keeping the factorized form
!>      U diag( something i ) U^T. It is important to only precondition in the subspace orthogonal to c0.
!>      not only is it the only part that matters, it also simplifies the computation of
!>      the lagrangian multipliers in the OT minimization  (i.e. if the c0 here is different
!>      from the c0 used in the OT setup, there will be a bug).
! **************************************************************************************************
   SUBROUTINE make_full_all(preconditioner_env, matrix_c0, matrix_h, matrix_s, c0_evals, energy_gap)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c0
      TYPE(dbcsr_type), POINTER                          :: matrix_h, matrix_s
      REAL(KIND=dp), DIMENSION(:)                        :: c0_evals
      REAL(KIND=dp)                                      :: energy_gap

      CHARACTER(len=*), PARAMETER                        :: routineN = 'make_full_all'
      ! fudge_factor scales the error estimate that bounds the energy gap from below,
      ! lambda_base is the minimal shift applied to the c0 block of the modified Hamiltonian
      REAL(KIND=dp), PARAMETER                           :: fudge_factor = 0.25_dp, &
                                                            lambda_base = 10.0_dp

      INTEGER                                            :: handle, k, n
      REAL(KIND=dp)                                      :: error_estimate, lambda
      REAL(KIND=dp), DIMENSION(:), POINTER               :: diag, norms, shifted_evals
      TYPE(cp_fm_struct_type), POINTER                   :: fm_struct_tmp
      TYPE(cp_fm_type)                                   :: matrix_hc0, matrix_left, matrix_s1, &
                                                            matrix_s2, matrix_sc0, matrix_shc0, &
                                                            matrix_tmp, ortho
      TYPE(cp_fm_type), POINTER                          :: matrix_pre

      ! Sets up the data of the state-by-state preconditioner in preconditioner_env:
      !   fm         (n x n) eigenvectors of a modified Hamiltonian, first k columns replaced by matrix_c0
      !   full_evals (n)     eigenvalues of the modified Hamiltonian, first k entries replaced by c0_evals
      !   occ_evals  (k)     copy of c0_evals
      !   energy_gap         MAX(energy_gap, fudge_factor*error_estimate)
      ! n and k are the global number of rows and columns of matrix_c0.
      CALL timeset(routineN, handle)

      ! release the dense matrix of a previously built preconditioner
      IF (ASSOCIATED(preconditioner_env%fm)) THEN
         CALL cp_fm_release(preconditioner_env%fm)
         DEALLOCATE (preconditioner_env%fm)
         NULLIFY (preconditioner_env%fm)
      END IF
      CALL cp_fm_get_info(matrix_c0, nrow_global=n, ncol_global=k)
      ! square n x n work matrices
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=n, ncol_global=n, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      ALLOCATE (preconditioner_env%fm)
      CALL cp_fm_create(preconditioner_env%fm, fm_struct_tmp, name="preconditioner_env%fm")
      CALL cp_fm_create(ortho, fm_struct_tmp, name="ortho")
      CALL cp_fm_create(matrix_tmp, fm_struct_tmp, name="matrix_tmp")
      CALL cp_fm_struct_release(fm_struct_tmp)
      ! NOTE(review): both arrays are allocated unconditionally, presumably the caller has released
      !               the arrays of an earlier preconditioner - confirm
      ALLOCATE (preconditioner_env%full_evals(n))
      ALLOCATE (preconditioner_env%occ_evals(k))

      ! 0) cholesky decompose the overlap matrix, if this fails the basis is singular,
      !    more than EPS_DEFAULT
      CALL copy_dbcsr_to_fm(matrix_s, ortho)
      CALL cp_fm_cholesky_decompose(ortho)
      ! with cholesky_inverse the factor is inverted once, all later transformations of that
      ! branch are triangular multiplications with ortho
      IF (preconditioner_env%cholesky_use == cholesky_inverse) THEN
         CALL cp_fm_triangular_invert(ortho)
      END IF
      ! 1) Construct a new H matrix, which has the current C0 as eigenvectors,
      !    possibly shifted by an amount lambda,
      !    and the same spectrum as the original H matrix in the space orthogonal to the C0
      !    with P=C0 C0 ^ T
      !    (1 - PS)^T H (1-PS) + (PS)^T (H - lambda S ) (PS)
      !    we exploit that the C0 are already the ritz states of H
      ! matrix_sc0 = S C0 and matrix_hc0 = H C0, both with the layout of matrix_c0
      CALL cp_fm_create(matrix_sc0, matrix_c0%matrix_struct, name="sc0")
      CALL cp_dbcsr_sm_fm_multiply(matrix_s, matrix_c0, matrix_sc0, k)
      CALL cp_fm_create(matrix_hc0, matrix_c0%matrix_struct, name="hc0")
      CALL cp_dbcsr_sm_fm_multiply(matrix_h, matrix_c0, matrix_hc0, k)

      ! An aside, try to estimate the error on the ritz values, we'll need it later on
      CALL cp_fm_create(matrix_shc0, matrix_c0%matrix_struct, name="shc0")

      ! matrix_shc0 = H C0 transformed with the transposed Cholesky factor information in ortho
      ! (presumably U^-T H C0 for S = U^T U - confirm against the cp_fm_cholesky routines)
      SELECT CASE (preconditioner_env%cholesky_use)
      CASE (cholesky_inverse)
         ! ortho holds the inverted factor: copy H C0 and multiply from the left with its transpose
         CALL cp_fm_to_fm(matrix_hc0, matrix_shc0)
         CALL cp_fm_triangular_multiply(ortho, matrix_shc0, side="L", transpose_tr=.TRUE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=n, n_cols=k, alpha=1.0_dp)
      CASE (cholesky_reduce)
         ! ortho holds the factor itself: request the transposed "SOLVE" operation
         CALL cp_fm_cholesky_restore(matrix_hc0, k, ortho, matrix_shc0, "SOLVE", transa="T")
      CASE DEFAULT
         CPABORT("cholesky type not implemented")
      END SELECT
      ! k x k matrix that receives (shc0)^T (shc0)
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=k, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      CALL cp_fm_create(matrix_s1, fm_struct_tmp, name="matrix_s1")
      CALL cp_fm_struct_release(fm_struct_tmp)
      ! since we only use diagonal elements this is a bit of a waste
      CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, matrix_shc0, matrix_shc0, 0.0_dp, matrix_s1)
      ALLOCATE (diag(k))
      CALL cp_fm_get_diag(matrix_s1, diag)
      ! largest deviation between the diagonal of (shc0)^T (shc0) and the squared Ritz values
      error_estimate = MAXVAL(SQRT(ABS(diag - c0_evals**2)))
      DEALLOCATE (diag)
      CALL cp_fm_release(matrix_s1)
      CALL cp_fm_release(matrix_shc0)
      ! we'll only use the energy gap, if our estimate of the error on the eigenvalues
      ! is small enough. A large error combined with a small energy gap would otherwise lead to
      ! an aggressive but bad preconditioner. Only when the error is small (MD), we can precondition
      ! aggressively
      preconditioner_env%energy_gap = MAX(energy_gap, error_estimate*fudge_factor)
      ! matrix_tmp = dense H; matrix_pre (an alias of preconditioner_env%fm) is only the second
      ! argument of cp_fm_uplo_to_full here, it is overwritten by the eigensolver in step 2
      CALL copy_dbcsr_to_fm(matrix_h, matrix_tmp)
      matrix_pre => preconditioner_env%fm
      CALL cp_fm_uplo_to_full(matrix_tmp, matrix_pre)
      ! tmp = H ( 1 - PS )
      CALL parallel_gemm('N', 'T', n, n, k, -1.0_dp, matrix_hc0, matrix_sc0, 1.0_dp, matrix_tmp)

      ! matrix_left (k x n) = C0^T tmp
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=n, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      CALL cp_fm_create(matrix_left, fm_struct_tmp, name="matrix_left")
      CALL cp_fm_struct_release(fm_struct_tmp)
      CALL parallel_gemm('T', 'N', k, n, n, 1.0_dp, matrix_c0, matrix_tmp, 0.0_dp, matrix_left)
      ! tmp = (1 - PS)^T H (1-PS)
      CALL parallel_gemm('N', 'N', n, n, k, -1.0_dp, matrix_sc0, matrix_left, 1.0_dp, matrix_tmp)
      CALL cp_fm_release(matrix_left)

      ! add the c0 block with Ritz values shifted down by lambda:
      ! tmp = tmp + (S C0) diag(c0_evals - lambda) (S C0)^T
      ALLOCATE (shifted_evals(k))
      lambda = lambda_base + error_estimate
      shifted_evals = c0_evals - lambda
      ! H C0 is no longer needed, matrix_hc0 is reused for the scaled copy of S C0
      CALL cp_fm_to_fm(matrix_sc0, matrix_hc0)
      CALL cp_fm_column_scale(matrix_hc0, shifted_evals)
      CALL parallel_gemm('N', 'T', n, n, k, 1.0_dp, matrix_hc0, matrix_sc0, 1.0_dp, matrix_tmp)

      ! 2) diagonalize this operator
      ! transform to a standard eigenvalue problem first (no CASE DEFAULT needed, an unsupported
      ! cholesky_use has already aborted above)
      SELECT CASE (preconditioner_env%cholesky_use)
      CASE (cholesky_inverse)
         CALL cp_fm_triangular_multiply(ortho, matrix_tmp, side="R", transpose_tr=.FALSE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=n, n_cols=n, alpha=1.0_dp)
         CALL cp_fm_triangular_multiply(ortho, matrix_tmp, side="L", transpose_tr=.TRUE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=n, n_cols=n, alpha=1.0_dp)
      CASE (cholesky_reduce)
         CALL cp_fm_cholesky_reduce(matrix_tmp, ortho)
      END SELECT
      ! eigenvectors go to matrix_pre, eigenvalues to preconditioner_env%full_evals
      CALL choose_eigv_solver(matrix_tmp, matrix_pre, preconditioner_env%full_evals)
      ! back-transform the eigenvectors; after either branch matrix_pre and matrix_tmp hold
      ! the same back-transformed vectors
      SELECT CASE (preconditioner_env%cholesky_use)
      CASE (cholesky_inverse)
         CALL cp_fm_triangular_multiply(ortho, matrix_pre, side="L", transpose_tr=.FALSE., &
                                        invert_tr=.FALSE., uplo_tr="U", n_rows=n, n_cols=n, alpha=1.0_dp)
         CALL cp_fm_to_fm(matrix_pre, matrix_tmp)
      CASE (cholesky_reduce)
         CALL cp_fm_cholesky_restore(matrix_pre, n, ortho, matrix_tmp, "SOLVE")
         CALL cp_fm_to_fm(matrix_tmp, matrix_pre)
      END SELECT

      ! test that the subspace remained conserved
      ! (diagnostic that is permanently disabled by the constant condition; matrix_s1, matrix_s2
      !  and norms are only used in here)
      IF (.FALSE.) THEN
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=k, &
                                  context=preconditioner_env%ctxt, &
                                  para_env=preconditioner_env%para_env)
         CALL cp_fm_create(matrix_s1, fm_struct_tmp, name="matrix_s1")
         CALL cp_fm_create(matrix_s2, fm_struct_tmp, name="matrix_s2")
         CALL cp_fm_struct_release(fm_struct_tmp)
         ALLOCATE (norms(k))
         CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, matrix_sc0, matrix_tmp, 0.0_dp, matrix_s1)
         CALL choose_eigv_solver(matrix_s1, matrix_s2, norms)
         WRITE (*, *) "matrix norm deviation (should be close to zero): ", MAXVAL(ABS(ABS(norms) - 1.0_dp))
         DEALLOCATE (norms)
         CALL cp_fm_release(matrix_s1)
         CALL cp_fm_release(matrix_s2)
      END IF

      ! 3) replace the lowest k evals and evecs with what they should be
      preconditioner_env%occ_evals = c0_evals
      ! notice, this choice causes the preconditioner to be constant when applied to sc0 (see apply_full_all)
      preconditioner_env%full_evals(1:k) = c0_evals
      ! put the k columns of matrix_c0 into matrix_pre, i.e. into preconditioner_env%fm
      CALL cp_fm_to_fm(matrix_c0, matrix_pre, k, 1, 1)

      ! release the work matrices, matrix_pre is not released because it aliases preconditioner_env%fm
      CALL cp_fm_release(matrix_sc0)
      CALL cp_fm_release(matrix_hc0)
      CALL cp_fm_release(ortho)
      CALL cp_fm_release(matrix_tmp)
      DEALLOCATE (shifted_evals)
      CALL timestop(handle)

   END SUBROUTINE make_full_all

! **************************************************************************************************
!> \brief full all in the orthonormal basis
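!> \param preconditioner_env supplies ctxt and para_env; on exit its components fm, full_evals,
!>        occ_evals and energy_gap are set
!> \param matrix_c0 current orbitals c0 (n rows, k columns)
!> \param matrix_h Hamiltonian matrix
!> \param c0_evals Ritz values belonging to the columns of matrix_c0
!> \param energy_gap requested energy gap; the stored gap is at least 0.25 times the estimated
!>        error of the Ritz values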
! **************************************************************************************************
   SUBROUTINE make_full_all_ortho(preconditioner_env, matrix_c0, matrix_h, c0_evals, energy_gap)

      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c0
      TYPE(dbcsr_type), POINTER                          :: matrix_h
      REAL(KIND=dp), DIMENSION(:)                        :: c0_evals
      REAL(KIND=dp)                                      :: energy_gap

      CHARACTER(len=*), PARAMETER :: routineN = 'make_full_all_ortho'
      ! fudge_factor scales the error estimate that bounds the energy gap from below,
      ! lambda_base is the minimal shift applied to the c0 block of the modified Hamiltonian
      REAL(KIND=dp), PARAMETER                           :: fudge_factor = 0.25_dp, &
                                                            lambda_base = 10.0_dp

      INTEGER                                            :: handle, k, n
      REAL(KIND=dp)                                      :: error_estimate, lambda
      REAL(KIND=dp), DIMENSION(:), POINTER               :: diag, norms, shifted_evals
      TYPE(cp_fm_struct_type), POINTER                   :: fm_struct_tmp
      TYPE(cp_fm_type)                                   :: matrix_hc0, matrix_left, matrix_s1, &
                                                            matrix_s2, matrix_sc0, matrix_tmp
      TYPE(cp_fm_type), POINTER                          :: matrix_pre

      ! Variant of make_full_all without an overlap matrix (S=1): no Cholesky factorization is
      ! needed and matrix_sc0 is a plain copy of matrix_c0. On exit preconditioner_env holds
      !   fm         (n x n) eigenvectors of a modified Hamiltonian, first k columns replaced by matrix_c0
      !   full_evals (n)     eigenvalues of the modified Hamiltonian, first k entries replaced by c0_evals
      !   occ_evals  (k)     copy of c0_evals
      !   energy_gap         MAX(energy_gap, fudge_factor*error_estimate)
      ! n and k are the global number of rows and columns of matrix_c0.
      CALL timeset(routineN, handle)

      ! release the dense matrix of a previously built preconditioner
      IF (ASSOCIATED(preconditioner_env%fm)) THEN
         CALL cp_fm_release(preconditioner_env%fm)
         DEALLOCATE (preconditioner_env%fm)
         NULLIFY (preconditioner_env%fm)
      END IF
      CALL cp_fm_get_info(matrix_c0, nrow_global=n, ncol_global=k)
      ! square n x n work matrices
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=n, ncol_global=n, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      ALLOCATE (preconditioner_env%fm)
      CALL cp_fm_create(preconditioner_env%fm, fm_struct_tmp, name="preconditioner_env%fm")
      CALL cp_fm_create(matrix_tmp, fm_struct_tmp, name="matrix_tmp")
      CALL cp_fm_struct_release(fm_struct_tmp)
      ! NOTE(review): both arrays are allocated unconditionally, presumably the caller has released
      !               the arrays of an earlier preconditioner - confirm
      ALLOCATE (preconditioner_env%full_evals(n))
      ALLOCATE (preconditioner_env%occ_evals(k))

      ! 1) Construct a new H matrix, which has the current C0 as eigenvectors,
      !    possibly shifted by an amount lambda,
      !    and the same spectrum as the original H matrix in the space orthogonal to the C0
      !    with P=C0 C0 ^ T
      !    (1 - PS)^T H (1-PS) + (PS)^T (H - lambda S ) (PS)
      !    we exploit that the C0 are already the ritz states of H
      ! without an overlap matrix "S C0" is just a copy of C0; matrix_hc0 = H C0
      CALL cp_fm_create(matrix_sc0, matrix_c0%matrix_struct, name="sc0")
      CALL cp_fm_to_fm(matrix_c0, matrix_sc0)
      CALL cp_fm_create(matrix_hc0, matrix_c0%matrix_struct, name="hc0")
      CALL cp_dbcsr_sm_fm_multiply(matrix_h, matrix_c0, matrix_hc0, k)

      ! An aside, try to estimate the error on the ritz values, we'll need it later on
      ! k x k matrix that receives (H C0)^T (H C0)
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=k, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      CALL cp_fm_create(matrix_s1, fm_struct_tmp, name="matrix_s1")
      CALL cp_fm_struct_release(fm_struct_tmp)
      ! since we only use diagonal elements this is a bit of a waste
      CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, matrix_hc0, matrix_hc0, 0.0_dp, matrix_s1)
      ALLOCATE (diag(k))
      CALL cp_fm_get_diag(matrix_s1, diag)
      ! largest deviation between the diagonal of (H C0)^T (H C0) and the squared Ritz values
      error_estimate = MAXVAL(SQRT(ABS(diag - c0_evals**2)))
      DEALLOCATE (diag)
      CALL cp_fm_release(matrix_s1)
      ! we'll only use the energy gap, if our estimate of the error on the eigenvalues
      ! is small enough. A large error combined with a small energy gap would otherwise lead to
      ! an aggressive but bad preconditioner. Only when the error is small (MD), we can precondition
      ! aggressively
      preconditioner_env%energy_gap = MAX(energy_gap, error_estimate*fudge_factor)

      ! matrix_tmp = dense H; matrix_pre (an alias of preconditioner_env%fm) is only the second
      ! argument of cp_fm_uplo_to_full here, it is overwritten by the eigensolver in step 2
      matrix_pre => preconditioner_env%fm
      CALL copy_dbcsr_to_fm(matrix_h, matrix_tmp)
      CALL cp_fm_uplo_to_full(matrix_tmp, matrix_pre)
      ! tmp = H ( 1 - PS )
      CALL parallel_gemm('N', 'T', n, n, k, -1.0_dp, matrix_hc0, matrix_sc0, 1.0_dp, matrix_tmp)

      ! matrix_left (k x n) = C0^T tmp
      CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=n, &
                               context=preconditioner_env%ctxt, &
                               para_env=preconditioner_env%para_env)
      CALL cp_fm_create(matrix_left, fm_struct_tmp, name="matrix_left")
      CALL cp_fm_struct_release(fm_struct_tmp)
      CALL parallel_gemm('T', 'N', k, n, n, 1.0_dp, matrix_c0, matrix_tmp, 0.0_dp, matrix_left)
      ! tmp = (1 - PS)^T H (1-PS)
      CALL parallel_gemm('N', 'N', n, n, k, -1.0_dp, matrix_sc0, matrix_left, 1.0_dp, matrix_tmp)
      CALL cp_fm_release(matrix_left)

      ! add the c0 block with Ritz values shifted down by lambda:
      ! tmp = tmp + C0 diag(c0_evals - lambda) C0^T
      ALLOCATE (shifted_evals(k))
      lambda = lambda_base + error_estimate
      shifted_evals = c0_evals - lambda
      ! H C0 is no longer needed, matrix_hc0 is reused for the scaled copy of matrix_sc0
      CALL cp_fm_to_fm(matrix_sc0, matrix_hc0)
      CALL cp_fm_column_scale(matrix_hc0, shifted_evals)
      CALL parallel_gemm('N', 'T', n, n, k, 1.0_dp, matrix_hc0, matrix_sc0, 1.0_dp, matrix_tmp)

      ! 2) diagonalize this operator
      ! eigenvectors go to matrix_pre, eigenvalues to preconditioner_env%full_evals
      CALL choose_eigv_solver(matrix_tmp, matrix_pre, preconditioner_env%full_evals)

      ! test that the subspace remained conserved
      ! (diagnostic that is permanently disabled by the constant condition; matrix_s1, matrix_s2
      !  and norms are only used in here)
      IF (.FALSE.) THEN
         CALL cp_fm_to_fm(matrix_pre, matrix_tmp)
         CALL cp_fm_struct_create(fm_struct_tmp, nrow_global=k, ncol_global=k, &
                                  context=preconditioner_env%ctxt, &
                                  para_env=preconditioner_env%para_env)
         CALL cp_fm_create(matrix_s1, fm_struct_tmp, name="matrix_s1")
         CALL cp_fm_create(matrix_s2, fm_struct_tmp, name="matrix_s2")
         CALL cp_fm_struct_release(fm_struct_tmp)
         ALLOCATE (norms(k))
         CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, matrix_sc0, matrix_tmp, 0.0_dp, matrix_s1)
         CALL choose_eigv_solver(matrix_s1, matrix_s2, norms)

         WRITE (*, *) "matrix norm deviation (should be close to zero): ", MAXVAL(ABS(ABS(norms) - 1.0_dp))
         DEALLOCATE (norms)
         CALL cp_fm_release(matrix_s1)
         CALL cp_fm_release(matrix_s2)
      END IF

      ! 3) replace the lowest k evals and evecs with what they should be
      preconditioner_env%occ_evals = c0_evals
      ! notice, this choice causes the preconditioner to be constant when applied to sc0 (see apply_full_all)
      preconditioner_env%full_evals(1:k) = c0_evals
      ! put the k columns of matrix_c0 into matrix_pre, i.e. into preconditioner_env%fm
      CALL cp_fm_to_fm(matrix_c0, matrix_pre, k, 1, 1)

      ! release the work matrices, matrix_pre is not released because it aliases preconditioner_env%fm
      CALL cp_fm_release(matrix_sc0)
      CALL cp_fm_release(matrix_hc0)
      CALL cp_fm_release(matrix_tmp)
      DEALLOCATE (shifted_evals)

      CALL timestop(handle)

   END SUBROUTINE make_full_all_ortho

! **************************************************************************************************
!> \brief Builds the sparse operator H - (SC)(2.0*CT*H*C - 1)(SC)^T + shift*S and stores it
!>        in preconditioner_env%sparse_matrix for later inversion.
!>        H is the Kohn-Sham matrix and C holds the MO coefficients. Without matrix_s the
!>        basis is treated as orthonormal: SC is a copy of C and shift is added on the diagonal.
!>        The middle term only acts on the occupied space (reversing it in energy order).
!>        shift*S moves the spectrum of the generalized form up; shift is chosen such that the
!>        lowest eigenvalue of the modified operator reaches
!>        MAX(1.5*(E_lumo - E_homo), energy_gap) and is zero if it is already that large.
!>        This form is implicitly multiplied from both sides by S^0.5, i.e. it corresponds to
!>        S^-0.5 H S^-0.5 + shift + (S^0.5 C) shifts (S^0.5 C)^T, which ensures that the
!>        correct quantity is preconditioned.
!>        Replaced the old full_single_inverse at revision 14616
!> \param preconditioner_env the preconditioner env; sparse_matrix, max_ev_vector and
!>        min_ev_vector are allocated here if not yet associated and are overwritten
!> \param matrix_c0 the MO coefficient matrix (fm)
!> \param matrix_h Kohn-Sham matrix (dbcsr)
!> \param energy_gap lower bound for the value the lowest unoccupied eigenvalue is shifted to
!> \param matrix_s the overlap matrix if not orthonormal (dbcsr, optional)
! **************************************************************************************************
   SUBROUTINE make_full_single_inverse(preconditioner_env, matrix_c0, matrix_h, energy_gap, matrix_s)
      TYPE(preconditioner_type)                          :: preconditioner_env
      TYPE(cp_fm_type), INTENT(IN)                       :: matrix_c0
      TYPE(dbcsr_type), POINTER                          :: matrix_h
      REAL(KIND=dp)                                      :: energy_gap
      TYPE(dbcsr_type), OPTIONAL, POINTER                :: matrix_s

      CHARACTER(len=*), PARAMETER :: routineN = 'make_full_single_inverse'

      INTEGER                                            :: nao, nmat, nmo, nrestart_lumo, timer
      LOGICAL                                            :: has_overlap
      REAL(KIND=dp)                                      :: e_homo, e_lumo, lumo_target, &
                                                            spectrum_shift
      TYPE(arnoldi_env_type)                             :: solver
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: operators
      TYPE(dbcsr_type), TARGET                           :: c_sparse, h_times_c, occ_block, &
                                                            s_times_c

      CALL timeset(routineN, timer)

      ! matrix_s may only be referenced when it is present; remember that once
      has_overlap = PRESENT(matrix_s)

      ! --- work space ---------------------------------------------------------------------------
      CALL cp_fm_get_info(matrix_c0, nrow_global=nao, ncol_global=nmo)
      ! The MOs arrive as a full matrix and are converted to dbcsr here. Having the sparse
      ! version available from the caller would be nicer, but for the time being this will do.
      CALL cp_fm_to_dbcsr_row_template(c_sparse, matrix_c0, matrix_h)
      CALL dbcsr_create(s_times_c, template=c_sparse)
      CALL dbcsr_create(h_times_c, template=c_sparse)
      CALL cp_dbcsr_m_by_n_from_template(occ_block, matrix_h, nmo, nmo, sym=dbcsr_type_symmetric)

      ! The output matrix is created on first use and re-used afterwards
      IF (.NOT. ASSOCIATED(preconditioner_env%sparse_matrix)) THEN
         ALLOCATE (preconditioner_env%sparse_matrix)
      END IF

      ! First term of the preconditioner: H
      CALL dbcsr_copy(preconditioner_env%sparse_matrix, matrix_h)

      ! S*C; for an orthonormal basis this is just C
      IF (has_overlap) THEN
         CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_s, c_sparse, 0.0_dp, s_times_c)
      ELSE
         CALL dbcsr_copy(s_times_c, c_sparse)
      END IF

      ! --- occupied subspace: compute it and shift it -------------------------------------------
      ! occ_block = C^T*H*C, the representation of H in the occupied space
      CALL dbcsr_multiply("N", "N", 1.0_dp, matrix_h, c_sparse, 0.0_dp, h_times_c)
      CALL dbcsr_multiply("T", "N", 1.0_dp, c_sparse, h_times_c, 0.0_dp, occ_block)

      ! An eigenvalue of C^T*H*C serves as the HOMO reference energy
      ALLOCATE (operators(1))
      operators(1)%matrix => occ_block
      CALL setup_arnoldi_env(solver, operators, max_iter=20, threshold=1.0E-3_dp, &
                             selection_crit=2, nval_request=1, nrestarts=8, &
                             generalized_ev=.FALSE., iram=.FALSE.)
      ! Start from the vector of the previous call if there is one
      IF (ASSOCIATED(preconditioner_env%max_ev_vector)) THEN
         CALL set_arnoldi_initial_vector(solver, preconditioner_env%max_ev_vector)
      END IF
      CALL arnoldi_ev(operators, solver)
      e_homo = REAL(get_selected_ritz_val(solver, 1), dp)

      ! Keep the Ritz vector as initial guess for the next call
      IF (.NOT. ASSOCIATED(preconditioner_env%max_ev_vector)) THEN
         ALLOCATE (preconditioner_env%max_ev_vector)
      END IF
      CALL get_selected_ritz_vector(solver, 1, operators(1)%matrix, preconditioner_env%max_ev_vector)
      CALL deallocate_arnoldi_env(solver)
      DEALLOCATE (operators)

      ! Move the occupied states a bit further up: -0.5 on the diagonal combined with the
      ! factor 2.0 below gives 2*C^T*H*C - 1, which is then subtracted from H
      CALL dbcsr_add_on_diag(occ_block, -0.5_dp)
      ! AO representation of the shift (S*C is needed, see header), a W-matrix like object.
      ! h_times_c is re-used as scratch space for (SC)*(2*C^T*H*C - 1)
      CALL dbcsr_multiply("N", "N", 2.0_dp, s_times_c, occ_block, 0.0_dp, h_times_c)
      CALL dbcsr_multiply("N", "T", -1.0_dp, h_times_c, s_times_c, 1.0_dp, &
                          preconditioner_env%sparse_matrix)

      ! --- lowest eigenvalue of the modified operator -------------------------------------------
      ! With an overlap matrix the generalized eigenvalue problem has to be solved; that case
      ! passes S as second matrix and allows more restarts
      IF (has_overlap) THEN
         nmat = 2
         nrestart_lumo = 21
      ELSE
         nmat = 1
         nrestart_lumo = 8
      END IF
      ALLOCATE (operators(nmat))
      operators(1)%matrix => preconditioner_env%sparse_matrix
      IF (has_overlap) operators(2)%matrix => matrix_s
      CALL setup_arnoldi_env(solver, operators, max_iter=20, threshold=2.0E-2_dp, &
                             selection_crit=3, nval_request=1, nrestarts=nrestart_lumo, &
                             generalized_ev=has_overlap, iram=.FALSE.)
      IF (ASSOCIATED(preconditioner_env%min_ev_vector)) THEN
         CALL set_arnoldi_initial_vector(solver, preconditioner_env%min_ev_vector)
      END IF

      ! The selected eigenvalue is taken as the LUMO energy
      CALL arnoldi_ev(operators, solver)
      e_lumo = REAL(get_selected_ritz_val(solver, 1), dp)

      ! Keep the LUMO vector for restarting in the next step
      IF (.NOT. ASSOCIATED(preconditioner_env%min_ev_vector)) THEN
         ALLOCATE (preconditioner_env%min_ev_vector)
      END IF
      CALL get_selected_ritz_vector(solver, 1, operators(1)%matrix, preconditioner_env%min_ev_vector)
      CALL deallocate_arnoldi_env(solver)
      DEALLOCATE (operators)

      ! --- shift the spectrum -------------------------------------------------------------------
      ! The LUMO is moved to 1.5 times the computed gap or to the external energy_gap, whichever
      ! is larger. The factor 1.5 was determined by trying. A LUMO that is already positive
      ! enough is left alone.
      lumo_target = MAX(1.5_dp*(e_lumo - e_homo), energy_gap)
      spectrum_shift = 0.0_dp
      IF (e_lumo < lumo_target) spectrum_shift = lumo_target - e_lumo

      IF (has_overlap) THEN
         CALL dbcsr_add(preconditioner_env%sparse_matrix, matrix_s, 1.0_dp, spectrum_shift)
      ELSE
         CALL dbcsr_add_on_diag(preconditioner_env%sparse_matrix, spectrum_shift)
      END IF

      ! --- clean up -----------------------------------------------------------------------------
      CALL dbcsr_release(c_sparse)
      CALL dbcsr_release(h_times_c)
      CALL dbcsr_release(s_times_c)
      CALL dbcsr_release(occ_block)

      CALL timestop(timer)

   END SUBROUTINE make_full_single_inverse

END MODULE preconditioner_makes

